import torch

# Build a small integer tensor of shape (2, 3, 1): two rows of three
# single-element columns, so the last dimension has size 1.
a = torch.tensor(
    [
        [[1], [2], [3]],
        [[3], [4], [5]],
    ]
)
print(a.shape)  # torch.Size([2, 3, 1])

# Dropping the size-1 dimension at index 2 leaves a (2, 3) tensor with the
# same values; squeeze only removes a dim when its size is exactly 1.
print(a.squeeze(dim=2).shape)  # torch.Size([2, 3])